source('codes/00_import_libraries.R')

# Window length; only used below to pick the output/os/*_{win}.rds result
# files (presumably the estimation window of those forecasts - confirm).
win      = 120
# First calendar year kept, both for the recession shading and for the
# cumulative R^2 paths.
beg_year = 1881


#------------------------------------------------------------------------------#
# data -------------------------------------------------------------------------
#------------------------------------------------------------------------------#

# NOTE(review): `dt` is not referenced anywhere else in this script; confirm it
# is still needed before keeping this read.
dt = readRDS('input/combined_data.rds')[['base']]

# Recession dummy turned into one shading rectangle per row (presumably one row
# per month): each rectangle spans from the row's month start to the next row's
# date. ymin/ymax are in the plotted y units (cumulative R^2 * 100).
recess = read_excel('input/us recession dummy.xlsx') %>% 
  mutate(
    date = as.Date(as.yearmon(ym)),
    xmin = date,
    xmax = lead(date),  # NA on the last row, so ggplot drops that rectangle
    ymin = -0.6,
    ymax = 0.2
  ) %>% 
  # Compare Dates instead of a yearmon against a 'Jan 1881'-style string:
  # parsing the month name is locale-dependent, and the cut-off should follow
  # beg_year. `date` is always the first of its month, so this keeps exactly
  # the months from January of beg_year onward. lead() runs before the filter,
  # so xmax is still taken from the unfiltered sequence.
  filter(date >= as.Date(paste0(beg_year, '-01-01')))



#------------------------------------------------------------------------------#
# function ---------------------------------------------------------------------
#------------------------------------------------------------------------------#

fn_cum_os_r2 = function(df, year) {
  # Cumulative out-of-sample R^2 path of a forecast against a benchmark.
  #
  # Keeps the rows of `df` whose `ym` falls in calendar year `year` or later,
  # puts them in chronological order, and for each row t computes
  #
  #   cum_r2[t] = ( sum_{s <= t} (y_test - y_bar)^2
  #               - sum_{s <= t} (y_test - y_hat)^2 )
  #               / sum_{all s} (y_test - y_bar)^2
  #
  # i.e. the running reduction in squared error of the forecast `y_hat`
  # relative to the benchmark `y_bar` (the historical y, per the original
  # comment), scaled by the benchmark's total squared error over the window.
  # A value > 0 means the forecast has accumulated less squared error than
  # the benchmark up to that date; the last value equals
  # 1 - SSE(y_hat) / SSE(y_bar) over the whole window.
  #
  # Args:
  #   df:   data frame with columns ym (anything as.Date() accepts, e.g. Date
  #         or zoo::yearmon), y_test, y_bar and y_hat.
  #   year: first calendar year to keep.
  # Returns:
  #   the kept rows of `df`, sorted by ym, with an added `cum_r2` column.

  yr   = as.integer(format(as.Date(df$ym), '%Y'))
  keep = which(yr >= year)               # which() drops NA years, like filter()
  dt   = df[keep[order(df$ym[keep])], ]  # cumsum() below needs time order

  err_bar = (dt$y_test - dt$y_bar)^2
  err_hat = (dt$y_test - dt$y_hat)^2
  dt$cum_r2 = (cumsum(err_bar) - cumsum(err_hat)) / sum(err_bar)

  dt
}



#------------------------------------------------------------------------------#
# apply ------------------------------------------------------------------------
#------------------------------------------------------------------------------#

# Cumulative OOS R^2 paths from beg_year onward for the two series compared in
# the figure: element 'War' of the OLS results and element 'topic' of the PLS
# results. Each keeps only the date key plus one renamed cum_r2 column, so the
# two can later be joined on ym.
ols = readRDS(glue('output/os/ols_mkt1_{win}.rds'))[['War']]
ols = select(fn_cum_os_r2(ols, beg_year), ym, War_r2 = cum_r2)

pls = readRDS(glue('output/os//pls_mkt1_{win}.rds'))[['topic']]
pls = select(fn_cum_os_r2(pls, beg_year), ym, PLS_r2 = cum_r2)


# Figure 5B: cumulative out-of-sample R^2 paths (in %) of the War and PLS
# series, drawn over the recession shading in `recess`, with each line labelled
# at its last month. Written as a 12 x 6 inch, 300 dpi TIFF.
gplot = local({
  # Long format: one row per (month, series). `lab` is non-empty only at the
  # final month so geom_text_repel() labels each line exactly once; it is
  # computed after `var` becomes a factor so the labels read 'War' / 'PLS'.
  plot_data = inner_join(ols, pls, by = 'ym') %>%
    pivot_longer(cols = -ym, names_to = 'var', values_to = 'val') %>%
    mutate(
      ym   = as.yearmon(ym),
      date = as.Date(ym),
      val  = 100 * val,
      var  = factor(var, levels = c('War_r2', 'PLS_r2'), labels = c('War', 'PLS')),
      lab  = if_else(ym == max(ym), as.character(var), '')
    )

  decade_breaks = seq.Date(as.Date('1880-01-01'), as.Date('2020-01-01'), by = '10 years')

  ggplot(plot_data, aes(x = date, y = val, linetype = var)) +
    # shading first so the lines are drawn on top of it
    geom_rect(
      data        = recess,
      mapping     = aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax,
                        fill = as.factor(recess)),
      inherit.aes = FALSE,
      alpha       = 1
    ) +
    scale_fill_manual(values = c('white', 'grey')) +
    geom_line(linewidth = 1) +
    labs(title = '', y = '%', x = '') +
    scale_y_continuous(breaks = pretty_breaks()) +
    scale_x_date(breaks = decade_breaks, date_labels = '%Y', expand = c(0.05, 0)) +
    geom_text_repel(
      aes(label = lab),
      max.overlaps = 300,
      size         = 7,
      fontface     = 'plain',
      nudge_x      = 10
    ) +
    theme_base() +
    theme(
      legend.position = 'none',
      text            = element_text(size = 25),
      plot.title      = element_text(hjust = 0.5),
      plot.background = element_rect(colour = NA)
    )
})


ggsave(
  filename = glue('output/figures/Figure_5B.tiff'),
  plot     = gplot,
  width    = 12,
  height   = 6,
  device   = 'tiff',
  dpi      = 300
)












